{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Analyze Errors and Explore Interpretability of Models\n",
    "\n",
    "This notebook demonstrates how to use the Error Analysis dashboard from the Responsible AI Widgets to understand the errors of a regression model trained on the superconductivity dataset.\n",
    "\n",
    "\n",
    "For more information on the dataset, please see the UCI repository:\n",
    "\n",
    "https://archive.ics.uci.edu/ml/datasets/Superconductivty+Data\n",
    "\n",
    "Paper citation:\n",
    "\n",
    "Hamidieh, Kam, A data-driven statistical model for predicting the critical temperature of a superconductor, Computational Materials Science, Volume 154, November 2018, Pages 346-354, [Web Link](https://doi.org/10.1016/j.commatsci.2018.07.052)\n",
    "\n",
    "\n",
    "The goal of this sample notebook is to predict the critical temperature with scikit-learn and explore model errors and explanations:\n",
    "\n",
    "1. Train an SVM regression model (`SVR`) using scikit-learn.\n",
    "2. Use Interpret-Community's `MimicExplainer` to generate model explanations.\n",
    "3. Visualize model errors and global and local explanations with the Error Analysis dashboard."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install Required Packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install --upgrade interpret-community\n",
    "# %pip install --upgrade raiwidgets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import required packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn import svm\n",
    "from sklearn.model_selection import train_test_split\n",
    "import pandas as pd\n",
    "\n",
    "from urllib.request import urlretrieve\n",
    "import requests\n",
    "import zipfile\n",
    "\n",
    "# Imports for SHAP MimicExplainer with LightGBM surrogate model\n",
    "from interpret.ext.blackbox import MimicExplainer\n",
    "from interpret.ext.glassbox import LGBMExplainableModel\n",
    "from interpret_community.common.constants import ModelTask\n",
    "\n",
    "from raiwidgets import ErrorAnalysisDashboard"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the dataset from UCI Repository: https://archive.ics.uci.edu/ml/datasets/Superconductivty+Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outdirname = 'superconduct'\n",
    "zipfilename = outdirname + '.zip'\n",
    "url = 'https://archive.ics.uci.edu/static/public/464/superconductivty+data.zip'\n",
    "# Temporary workaround: download with requests (SSL verification disabled)\n",
    "# instead of urlretrieve until the UCI website's SSL certificate is renewed\n",
    "# urlretrieve(url, zipfilename)\n",
    "content = requests.get(url, verify=False).content\n",
    "with open(zipfilename, mode='wb') as localfile:\n",
    "    localfile.write(content)\n",
    "with zipfile.ZipFile(zipfilename, 'r') as unzip:\n",
    "    unzip.extractall('.')\n",
    "df = pd.read_csv(r'./train.csv')\n",
    "y = df['critical_temp'].values\n",
    "X = df.drop(columns='critical_temp')\n",
    "feature_names = list(X.columns)\n",
    "X = X.values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep a random half of the data, then split that half 70/30 into train and test\n",
    "X, _, y, _ = train_test_split(X, y, test_size=0.5, random_state=0)\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep 30% of the test set as a smaller sample for computing explanations\n",
    "X_test_sample, _, y_test_sample, _ = train_test_split(X_test, y_test, test_size=0.7, random_state=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"X_train size: \")\n",
    "print(X_train.shape)\n",
    "print(\"X_test size: \")\n",
    "print(X_test.shape)\n",
    "print(\"X_test_sample size: \")\n",
    "print(X_test_sample.shape)"
   ]
  },
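  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three `train_test_split` calls above compose: half of the rows are set aside and not used, the remaining half is split 70/30 into train and test, and 30% of the test set is kept as a smaller sample. The following minimal sketch reproduces that chain on a toy array so the resulting proportions are easy to check. The `toy_*` names and the 1,000-row size are assumptions for illustration and are not part of the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Hypothetical toy data: 1,000 rows, 2 features (not the real dataset)\n",
    "toy_X = np.arange(2000).reshape(1000, 2)\n",
    "toy_y = np.arange(1000)\n",
    "\n",
    "# Step 1: keep a random half of the rows\n",
    "toy_X_half, _, toy_y_half, _ = train_test_split(toy_X, toy_y, test_size=0.5, random_state=0)\n",
    "# Step 2: 70/30 train/test split of the retained half\n",
    "toy_X_train, toy_X_test, toy_y_train, toy_y_test = train_test_split(\n",
    "    toy_X_half, toy_y_half, test_size=0.3, random_state=0)\n",
    "# Step 3: keep 30% of the test rows as the smaller sample\n",
    "toy_X_sample, _, toy_y_sample, _ = train_test_split(\n",
    "    toy_X_test, toy_y_test, test_size=0.7, random_state=0)\n",
    "\n",
    "# Report each subset's size as a fraction of the full toy dataset\n",
    "toy_n = len(toy_X)\n",
    "for toy_name, toy_part in [('kept', toy_X_half), ('train', toy_X_train),\n",
    "                           ('test', toy_X_test), ('test sample', toy_X_sample)]:\n",
    "    print('{}: {} rows ({:.1%} of all rows)'.format(toy_name, len(toy_part), len(toy_part) / toy_n))"
   ]
  },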
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train an SVM regression model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf = svm.SVR(gamma=0.001, C=100., tol=0.1)\n",
    "model = clf.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean absolute error on the test set; note that the model's errors are substantial\n",
    "print(\"average abs error on test dataset: \" + str(sum(abs(model.predict(X_test) - y_test))/y_test.shape[0]))"
   ]
  },
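  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The value printed above is the mean absolute error (MAE): the average of the absolute differences between predictions and true values. As a minimal sketch on made-up numbers (the `toy_*` arrays are hypothetical and unrelated to the dataset), the manual formula agrees with scikit-learn's `mean_absolute_error`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import mean_absolute_error\n",
    "\n",
    "# Hypothetical values, for illustration only\n",
    "toy_true = np.array([10.0, 20.0, 30.0, 40.0])\n",
    "toy_pred = np.array([12.0, 18.0, 33.0, 40.0])\n",
    "\n",
    "# Manual MAE: the same computation as the cell above, written with NumPy\n",
    "toy_mae_manual = np.mean(np.abs(toy_pred - toy_true))\n",
    "# Equivalent scikit-learn helper\n",
    "toy_mae_sklearn = mean_absolute_error(toy_true, toy_pred)\n",
    "\n",
    "print('manual MAE: {}, sklearn MAE: {}'.format(toy_mae_manual, toy_mae_sklearn))"
   ]
  },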
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load a simple Error Analysis view without explanations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions = model.predict(X_test)\n",
    "ErrorAnalysisDashboard(dataset=X_test, true_y=y_test, features=feature_names,\n",
    "                       pred_y=predictions, model_task='regression',\n",
    "                       max_depth=3)"
   ]
  },
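  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dashboard's tree view groups the data into cohorts with different error levels. The idea can be sketched, under stated assumptions, by fitting a shallow decision tree to a model's absolute errors: leaves with a high mean error point to feature ranges where the model struggles. This is not the dashboard's actual implementation; the synthetic data, the `toy_*` names, and the use of scikit-learn's `DecisionTreeRegressor` are assumptions made for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.tree import DecisionTreeRegressor, export_text\n",
    "\n",
    "toy_rng = np.random.RandomState(0)\n",
    "# Synthetic data: the target has a step in feature f0 and is linear in f1\n",
    "toy_X = toy_rng.uniform(-1, 1, size=(500, 2))\n",
    "toy_y = np.where(toy_X[:, 0] > 0.5, 5.0, 0.0) + toy_X[:, 1]\n",
    "\n",
    "# A deliberately simple model that cannot capture the step in f0\n",
    "toy_model = LinearRegression().fit(toy_X, toy_y)\n",
    "toy_abs_error = np.abs(toy_model.predict(toy_X) - toy_y)\n",
    "\n",
    "# Shallow tree trained to predict the absolute error from the features;\n",
    "# max_depth=3 mirrors the depth passed to the dashboard above\n",
    "toy_error_tree = DecisionTreeRegressor(max_depth=3, random_state=0)\n",
    "toy_error_tree.fit(toy_X, toy_abs_error)\n",
    "\n",
    "# Each leaf is a cohort; its value is the mean absolute error in that cohort\n",
    "print(export_text(toy_error_tree, feature_names=['f0', 'f1']))"
   ]
  },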
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train a surrogate model to explain the original blackbox model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the LightGBM surrogate model using MimicExplainer\n",
    "model_task = ModelTask.Regression\n",
    "explainer = MimicExplainer(model, X_train, LGBMExplainableModel,\n",
    "                           augment_data=True, max_num_of_augmentations=10,\n",
    "                           features=feature_names, model_task=model_task)"
   ]
  },
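  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`MimicExplainer` trains an interpretable surrogate model to approximate the predictions of the black-box model, and explanations are then derived from the surrogate. The following minimal sketch shows that general idea with plain scikit-learn. It swaps in a shallow `DecisionTreeRegressor` for the LightGBM surrogate and uses synthetic data, so the `toy_*` names and all settings are assumptions and not what Interpret-Community does internally. The R-squared score between surrogate and black-box predictions indicates how faithfully the surrogate mimics the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import r2_score\n",
    "from sklearn.svm import SVR\n",
    "from sklearn.tree import DecisionTreeRegressor\n",
    "\n",
    "toy_rng = np.random.RandomState(0)\n",
    "toy_X = toy_rng.uniform(-2, 2, size=(300, 3))\n",
    "toy_y = toy_X[:, 0] ** 2 + 0.5 * toy_X[:, 1]  # the third feature is irrelevant\n",
    "\n",
    "# Black-box model (stands in for the SVR trained above)\n",
    "toy_blackbox = SVR(gamma='scale', C=10.0).fit(toy_X, toy_y)\n",
    "toy_blackbox_pred = toy_blackbox.predict(toy_X)\n",
    "\n",
    "# The surrogate is fit to the black-box predictions, not to the true labels\n",
    "toy_surrogate = DecisionTreeRegressor(max_depth=4, random_state=0)\n",
    "toy_surrogate.fit(toy_X, toy_blackbox_pred)\n",
    "\n",
    "# Fidelity: how well the surrogate reproduces the black-box outputs\n",
    "toy_fidelity = r2_score(toy_blackbox_pred, toy_surrogate.predict(toy_X))\n",
    "print('surrogate fidelity (R^2): {:.3f}'.format(toy_fidelity))\n",
    "# The surrogate's importances serve as an approximate global explanation\n",
    "print('surrogate feature importances:', toy_surrogate.feature_importances_)"
   ]
  },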
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate global explanations\n",
    "Explain overall model predictions (global explanation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass the test sample as evaluation examples; it must be a representative sample of the original data.\n",
    "# X_train can be passed instead: more examples may give more accurate explanations, but they take longer to compute.\n",
    "global_explanation = explainer.explain_global(X_test_sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print out a dictionary that holds the sorted feature importance names and values\n",
    "print('global importance rank: {}'.format(global_explanation.get_feature_importance_dict()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Analyze model errors and explanations using the Error Analysis dashboard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ErrorAnalysisDashboard(global_explanation, model, dataset=X_test,\n",
    "                       true_y=y_test_sample, true_y_dataset=y_test,\n",
    "                       model_task='regression', max_depth=3)"
   ]
  }
 ],
 "metadata": {
  "authors": [
   {
    "name": "mesameki"
   }
  ],
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
